#coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.
from PyQt5 import QtWidgets
from PyQt5.QtWidgets import QLineEdit,QDialog,QVBoxLayout,QHBoxLayout,QDialogButtonBox,QLabel,QFileDialog,QProgressBar,QComboBox
from PyQt5.QtCore import Qt
from ml_feature.util.Transfer import graphdef_mlc_json
import os


class PlaceholderDialog(QDialog):
    """
    Dialog that lets the user enter preset values for graph placeholders.

    One row (a QLabel with the node name plus a QLineEdit) is shown for each
    entry of ``placeholder_default_value``. The line edit is pre-filled from
    the entry's default value, and an OK/Cancel button box closes the dialog.
    """
    def __init__(self, placeholder_default_value, parent=None):
        super(PlaceholderDialog, self).__init__(parent)
        # Mapping of placeholder node name -> default value used to pre-fill
        # the editors.
        self.placeholder_default_value = placeholder_default_value

        vlayout = QVBoxLayout(self)
        # Maps each row's QLabel (whose text is the node name) to its QLineEdit.
        self.placeholder_set_value = dict()
        for node_name, default_value in self.placeholder_default_value.items():
            # Created without a parent: the dialog already owns ``vlayout``,
            # and passing ``self`` here makes Qt warn that the widget already
            # has a layout. addLayout() below re-parents the row correctly.
            hlayout = QHBoxLayout()
            name_label = QLabel(node_name, self)
            if default_value:
                # Drops the first and last characters of the string form
                # (presumably the brackets of a list/tuple repr - TODO confirm
                # against callers).
                value = QLineEdit(str(default_value)[1: -1], self)
            else:
                value = QLineEdit('', self)
            self.placeholder_set_value[name_label] = value
            hlayout.addWidget(name_label)
            hlayout.addWidget(value)
            vlayout.addLayout(hlayout)
        # OK and Cancel buttons
        buttons = QDialogButtonBox(
            QDialogButtonBox.Cancel |
            QDialogButtonBox.Ok,
            Qt.Horizontal, self)
        buttons.rejected.connect(self.reject)
        buttons.accepted.connect(self.accept)
        vlayout.addWidget(buttons)
        self.setLayout(vlayout)

    def getValue(self):
        """
        Return the values currently typed by the user.

        :return: dict mapping node name (label text) -> line-edit text.
        """
        return {label.text(): editor.text()
                for label, editor in self.placeholder_set_value.items()}

    @staticmethod
    def getPreValue(name, parent=None):
        """
        Show the dialog modally and collect the user's input.

        :param name: mapping of placeholder node name -> default value
                     (passed through as ``placeholder_default_value``).
        :param parent: optional parent widget.
        :return: tuple ``(values, accepted)`` where ``values`` is the dict from
                 ``getValue()`` and ``accepted`` is True if OK was pressed.
        """
        dialog = PlaceholderDialog(name, parent)
        result = dialog.exec_()
        value = dialog.getValue()
        return value, result == QDialog.Accepted


class Transfer_format(QDialog):
    """
    Export dialog: converts ``graph_def`` into a pseudo-instruction stream
    file (``.mlc``) in a user-selected directory.

    The user picks a directory, types a file name (without suffix) and presses
    "start transfer". ``getTransFlag()`` reports whether the last conversion
    succeeded.
    """
    def __init__(self, graph_def=None, sequence=None, parent=None):
        super(Transfer_format, self).__init__(parent)
        self.graph_def = graph_def
        # Set to True once a conversion finishes without raising.
        self.trans_flag = False
        self.sequence = sequence
        # Empty until the user picks a directory; transfer() is a no-op until then.
        self.save_directory = ''

        vlayout = QVBoxLayout(self)
        dir_btn = QtWidgets.QPushButton(self)
        dir_btn.setText('select directory')
        dir_btn.clicked.connect(self.select_directory)
        self.save_name_edit = QLineEdit(self)
        self.transer_start_btn = QtWidgets.QPushButton(self)
        self.transer_start_btn.setText('start transfer')
        self.transer_start_btn.clicked.connect(self.transfer)
        vlayout.addWidget(dir_btn)
        vlayout.addWidget(self.save_name_edit)
        vlayout.addWidget(self.transer_start_btn)
        self.setLayout(vlayout)

    def select_directory(self):
        """Ask the user for the output directory and remember it."""
        # Qt does not expand '~' itself, so resolve the home directory here.
        directory = QFileDialog.getExistingDirectory(
            self, 'select new model file directory', os.path.expanduser('~'))
        # getExistingDirectory returns '' when the user cancels; keep any
        # previously chosen directory instead of silently falling back to
        # the current working directory.
        if directory:
            self.save_directory = directory

    def getTransFlag(self):
        """Return True if the last conversion succeeded."""
        return self.trans_flag

    def transfer(self):
        """
        Convert ``graph_def`` and write it to ``<save_directory>/<name>.mlc``.

        Does nothing until a directory has been chosen, a graph is set and a
        non-empty file name has been entered. Otherwise ``trans_flag`` is set
        to True on success; on failure the traceback is printed and the flag
        is set to False. The dialog is closed after the attempt either way.
        """
        self.save_name = self.save_name_edit.text()
        self.save_suffix = '.mlc'
        if not self.save_directory:
            return
        if not self.graph_def or not self.save_name:
            return

        from ml_feature.util import get_mlc_stream_encoder, get_config_mlc_feature
        import traceback
        trans_func, _ = get_mlc_stream_encoder()
        mlc_feature_list = get_config_mlc_feature()
        try:
            json_ordered_dict = graphdef_mlc_json(self.graph_def, mlc_feature_list, self.sequence)
            url = os.path.abspath(os.path.join(self.save_directory, self.save_name + self.save_suffix))
            with open(url, 'w') as f:
                f.write(trans_func(json_ordered_dict))
        except Exception:
            # GUI boundary: report the failure rather than crash the application.
            self.trans_flag = False
            traceback.print_exc()
        else:
            self.trans_flag = True
        self.close()


class Save_as(QDialog):
    """
    "Save as" dialog: writes ``graph_def`` to a ``.pb`` (binary) or
    ``.pbtxt`` (text) file in a user-selected directory.

    ``getTransFlag()`` reports whether the last save succeeded.
    """
    def __init__(self, graph_def=None, parent=None):
        super(Save_as, self).__init__(parent)
        self.graph_def = graph_def
        # Set to True once a save finishes without raising.
        self.trans_flag = False
        # Empty until the user picks a directory; save() is a no-op until then.
        self.save_directory = ''

        vlayout = QVBoxLayout(self)
        dir_btn = QtWidgets.QPushButton(self)
        dir_btn.setText('select directory')
        dir_btn.clicked.connect(self.select_directory)
        self.save_name_edit = QLineEdit(self)
        self.combox = QComboBox()
        self.combox.addItem('.pb')
        self.combox.addItem('.pbtxt')
        self.combox.setCurrentIndex(0)
        self.transer_start_btn = QtWidgets.QPushButton(self)
        self.transer_start_btn.setText('save')
        self.transer_start_btn.clicked.connect(self.save)
        vlayout.addWidget(dir_btn)
        vlayout.addWidget(self.save_name_edit)
        vlayout.addWidget(self.combox)
        vlayout.addWidget(self.transer_start_btn)
        self.setLayout(vlayout)

    def select_directory(self):
        """Ask the user for the output directory and remember it."""
        # Qt does not expand '~' itself, so resolve the home directory here.
        directory = QFileDialog.getExistingDirectory(
            self, 'select new model file directory', os.path.expanduser('~'))
        # getExistingDirectory returns '' when the user cancels; keep any
        # previously chosen directory instead of silently falling back to
        # the current working directory.
        if directory:
            self.save_directory = directory

    def getTransFlag(self):
        """Return True if the last save succeeded."""
        return self.trans_flag

    def save(self):
        """
        Write ``graph_def`` to ``<save_directory>/<name><suffix>``.

        The suffix comes from the combo box: ``.pb`` writes the binary form,
        anything else (``.pbtxt``) writes the text form. Does nothing until a
        directory has been chosen, a graph is set and a non-empty file name
        has been entered. On failure the traceback is printed and
        ``trans_flag`` is set to False. The dialog is closed after the
        attempt either way.
        """
        self.save_name = self.save_name_edit.text()
        self.save_suffix = self.combox.currentText()
        if not self.save_directory:
            return
        if not self.graph_def or not self.save_name:
            return

        import traceback
        import tensorflow as tf
        as_text = self.save_suffix != '.pb'
        try:
            # TF1 exposes tf.train.write_graph; TF2 only provides tf.io.write_graph.
            write_graph = getattr(tf.train, 'write_graph', None) or tf.io.write_graph
            write_graph(self.graph_def, self.save_directory,
                        self.save_name + self.save_suffix, as_text=as_text)
        except Exception:
            # GUI boundary: report the failure rather than crash the application.
            self.trans_flag = False
            traceback.print_exc()
        else:
            self.trans_flag = True
        self.close()


class ProgressBarDialog(QDialog):
    """
    Progress-bar dialog shown over the main window.

    ``time`` is used as the bar's maximum value. Callers advance the bar and
    update its caption with ``setProgress`` and dismiss it with ``closeDialog``.
    """
    def __init__(self, time, parent=None):
        super(ProgressBarDialog, self).__init__(parent)
        self.time = time
        self.initUI()

    def initUI(self):
        """Create the caption label and the progress bar, stacked vertically."""
        layout = QVBoxLayout(self)

        caption = QLabel('', self)
        # NOTE(review): the layout presumably overrides this geometry; kept
        # for behavior parity.
        caption.setGeometry(0, 0, 100, 30)
        self.progressText = caption

        bar = QProgressBar(self)
        bar.setMaximum(self.time)
        bar.setValue(0)
        self.progress = bar

        for widget in (self.progressText, self.progress):
            layout.addWidget(widget)
        self.setLayout(layout)

    def setProgress(self, text, time):
        """Move the bar to ``time`` and show ``text`` as its caption."""
        bar, caption = self.progress, self.progressText
        bar.setValue(time)
        caption.setText(text)

    def closeDialog(self):
        """Close this dialog."""
        self.close()
